import threading
import torch


def move_to_cpu(obj):
    if isinstance(obj, torch.Tensor):
        return obj.cpu()
    elif isinstance(obj, dict):
        return {k: move_to_cpu(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return type(obj)(move_to_cpu(v) for v in obj)
    else:
        return obj


class PrefetchMicroBatchIterator:
    def __init__(self, mbs):
        self.mbs = mbs
        self.n = len(mbs)
        self.idx = 0
        self.cpu_cache = [None] * self.n
        self.cpu_cache[0] = mbs[0]

        if self.n > 1:
            self.prefetch_thread = threading.Thread(target=self._prefetch_to_cpu, daemon=True)
            self.prefetch_thread.start()
        else:
            self.prefetch_thread = None

    def _prefetch_to_cpu(self):
        if torch.npu.is_available():
            torch.npu.set_device(torch.npu.current_device())
        for i in range(1, self.n):
            self.cpu_cache[i] = move_to_cpu(self.mbs[i])

    def __iter__(self):
        return self

    def __next__(self):
        if self.idx >= self.n:
            self.close()
            raise StopIteration

        if self.idx == 0:
            data = self.cpu_cache[0]
        else:
            if self.idx == 1 and self.prefetch_thread is not None and self.prefetch_thread.is_alive():
                self.prefetch_thread.join()
            data = self.cpu_cache[self.idx]

        self.idx += 1
        return data

    def close(self):
        if self.prefetch_thread is not None and self.prefetch_thread.is_alive():
            self.prefetch_thread.join()
        self.cpu_cache = None
        self.mbs = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()